#ifndef STATSxx_MACHINE_LEARNING_DBN_HPP
#define STATSxx_MACHINE_LEARNING_DBN_HPP

// STL
#include <string>  // std::string
//#include <utility> // std::pair<>, std::make_pair()
#include <vector>  // std::vector<>

// Boost
#include <boost/serialization/serialization.hpp> // boost::serialization::
#include <boost/serialization/vector.hpp>        // serialize std::vector<>

// jScience
#include "jScience/linalg.hpp" // Matrix<>, Vector<>

// stats++
#include "statsxx/machine_learning/RBM.hpp" // machine_learning::RBM


// TODO: the training parameters should probably be read from a parameters file (with default values, etc.); the same applies to the RBM


namespace machine_learning
{
    //=========================================================
    // DEEP BELIEF NETWORK
    //=========================================================
    //
    // A DBN is stored as a collection of RBMs (the public member `RBM`).
    //
    // NOTE: the DBN settings are directly analogous to those of the RBM
    //
    // NOTE: one design principle of the DBN is that the same (or directly analogous) subroutines as in the RBM are used,
    //       with the per-RBM quantities wrapped in std::vector<>
    // NOTE: ... e.g., compare the v_to_h() overloads of DBN with those of RBM
    //
    class DBN
    {

    public:

        // the RBMs making up the network --- presumably ordered from the input layer upward (TODO confirm against train.cpp / v_to_h.cpp)
        // NOTE: this is the only state that serialize() writes/reads
        std::vector<machine_learning::RBM> RBM;


        DBN();

        // nRBM           : number of RBMs in the network
        // nv_, nh_       : number of visible / hidden units --- presumably one entry per RBM (TODO confirm in DBN.cpp)
        // vtype_, htype_ : visible / hidden unit types --- presumably the RBM encoding (0 == binary; 1 == continuous; 2 == ReLU; 3 == softplus); confirm
        DBN(
            const int              nRBM,
            // ---
            const std::vector<int> nv_,
            const std::vector<int> nh_,
            // ---
            const std::vector<int> vtype_,
            const std::vector<int> htype_
            );

        ~DBN();

        // NOTE: because the destructor above is user-declared, the compiler would NOT implicitly generate the move operations,
        // ... so moving a DBN (std::move(), return by value into an existing object, etc.) would silently deep-copy every RBM
        // ... the copy and move operations are therefore explicitly defaulted (a defaulted operation that cannot be generated is simply deleted)
        DBN(const DBN &)             = default;
        DBN(DBN &&)                  = default;
        DBN &operator=(const DBN &)  = default;
        DBN &operator=(DBN &&)       = default;

        // NOTE: the calling sequence is identical to that in RBM
        void train(
                   const int                  nepoch,
                   const int                  nbatches,
                   // ---
                   const int                  K_start,            // starting number of CD steps
                   const int                  K_end,              // ending number
                   const double               K_rate,             // rate at which K_start transitions to K_end
                   // ---
                   const bool                 sample_v,           // whether to sample v during reconstructions --- sampling is more correct and may lead to better density models (Hinton, practical guide), but NOT sampling (i.e., using the probabilities) reduces sampling noise, thus allowing faster learning
                   const bool                 sample_hdata,       // whether to sample h for the positive statistics --- true is closer to the mathematical model of an RBM
                   // ---
                   double                     lr,
                   const std::vector<double> &lr_dot,
                   const std::vector<double> &lr_adj,
                   const double               lr_min,
                   const double               lr_max,
                   // ---
                   const double               momentum_start,
                   const double               momentum_end,
                   const double               momentum_rate,
                   const std::vector<double> &momentum_dot,
                   const std::vector<double> &momentum_adj,
                   // ---
                   const double               weight_penalty,
                   // ---
                   const int                  free_energy_npts,   // number of training points to (randomly) select for free energy
                   const int                  free_energy_nepoch, // evaluate the free energy every number of epochs
                   const int                  free_energy_window, // calculate the slope of the free energy over this window
                   const double               convg_criterion_max_inc_max_w,
                   const int                  convg_criterion_nno_improvement,
                   // ---
                   Matrix<double>             X,                  // training set
                   Matrix<double>             X_val,              // validation set
                   // ---
                   const std::string          prefix
                   );

        // propagate a single visible vector v up through the network
        // ... returns the hidden-layer results --- presumably one Vector<> per RBM (see v_to_h.cpp)
        std::vector<Vector<double>> v_to_h(
                                           const Vector<double> &v
                                           ) const;

        // batch analogue of the above, for a matrix V of visible vectors
        // ... presumably one Matrix<> per RBM (see v_to_h.cpp)
        std::vector<Matrix<double>> v_to_h(
                                           const Matrix<double> &V
                                           ) const;

/*
        std::pair<
                  Vector<double>,
                  Vector<double>
                  > reconstruct(
                                const Vector<double> &v
                                ) const;

        std::pair<
                  Matrix<double>,
                  Matrix<double>
                  > reconstruct(
                                const Matrix<double> &V
                                ) const;

        double free_energy(
                           const Vector<double> &v
                           ) const;

        std::vector<double> free_energy(
                                        const Matrix<double> &V
                                        ) const;

        Matrix<double> get_W() const;
        Vector<double> get_a() const;
        Vector<double> get_b() const;

        // TODO: tmp because I wanted to assign easily ...
        // ... probably better to either provide subroutines to set_W, set_a, etc. --- or do this from a constructor
        Matrix<double> W;
        Vector<double> a;
        Vector<double> b;
*/

        // network architecture (see architecture.cpp) --- presumably the layer sizes; confirm
        std::vector<int> architecture() const;

        // per-RBM parameters, mirroring RBM::get_W(), get_a(), get_b() --- presumably one entry per RBM (see get_W.cpp etc.)
        std::vector<Matrix<double>> get_W() const;
        std::vector<Vector<double>> get_a() const;
        std::vector<Vector<double>> get_b() const;

    private:

        // (Boost) SERIALIZATION
        friend class boost::serialization::access;

        // serialize/deserialize the DBN by archiving its RBMs
        // ... the class version number is not used
        template<class Archive>
        void serialize(
                       Archive            &ar,
                       const unsigned int  /* version */
                       )
        {
            ar & this->RBM;
        }


/*
        std::vector<int> nv;    // number of visible units
        std::vector<int> nh;    // ... hidden

        std::vector<int> vtype; // 0 == binary; 1 == continuous; 2 == ReLU; 3 == softplus
        std::vector<int> htype; // ... same

        //Matrix<double> W;
        //Vector<double> a;
        //Vector<double> b;
*/
/*
        double free_energy_visible(
                                   const Vector<double> &v
                                   ) const;
        double free_energy_hidden(
                                  const Vector<double> &v
                                  ) const;
        double free_energy_const() const;

        Vector<double> calculate_p(
                                   const Vector<double> &x,
                                   const int             type
                                   );

        Vector<double> sample(
                              const Vector<double> &x,
                              const Vector<double> &px,
                              const int             type
                              );
*/
    };

}

#include "statsxx/machine_learning/DBN/architecture.cpp"
#include "statsxx/machine_learning/DBN/DBN.cpp"
#include "statsxx/machine_learning/DBN/get_W.cpp"
#include "statsxx/machine_learning/DBN/get_a.cpp"
#include "statsxx/machine_learning/DBN/get_b.cpp"
#include "statsxx/machine_learning/DBN/train.cpp"
#include "statsxx/machine_learning/DBN/v_to_h.cpp"


#endif
